{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp callback.noisy_student"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Noisy student"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Callback that applies Noisy Student self-training (a semi-supervised learning approach), based on:\n",
    "\n",
    "Xie, Q., Luong, M. T., Hovy, E., & Le, Q. V. (2020). \n",
    "<span style=\"color:dodgerblue\">Self-training with Noisy Student improves ImageNet classification</span>. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10687-10698)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# Share tensors between processes (e.g. DataLoader workers) through the file system instead of the\n",
    "# platform default strategy (file descriptors on Linux), which can exhaust the open-file limit.\n",
    "# NOTE: this runs at import time of the exported module and affects the whole Python process.\n",
    "import torch.multiprocessing\n",
    "torch.multiprocessing.set_sharing_strategy('file_system')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from tsai.imports import *\n",
    "from tsai.utils import *\n",
    "from tsai.data.preprocessing import *\n",
    "from tsai.data.transforms import *\n",
    "from tsai.models.layers import *\n",
    "from fastai.callback.all import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "\n",
    "# This is an unofficial implementation of Noisy Student based on:\n",
    "# Xie, Q., Luong, M. T., Hovy, E., & Le, Q. V. (2020). Self-training with noisy student improves imagenet classification.\n",
    "# In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10687-10698).\n",
    "# Official TensorFlow implementation available at https://github.com/google-research/noisystudent\n",
    "\n",
    "\n",
    "class NoisyStudent(Callback):\n",
    "    \"\"\"A callback to implement the Noisy Student approach. In the original paper this was used in combination with noise:\n",
    "        - stochastic depth: .8\n",
    "        - RandAugment: N=2, M=27\n",
    "        - dropout: .5\n",
    "\n",
    "    Steps:\n",
    "        1. Build the dl you will use as a teacher\n",
    "        2. Create dl2 with the pseudolabels (either soft or hard preds)\n",
    "        3. Pass any required batch_tfms to the callback\n",
    "\n",
    "    During training, each batch is the concatenation of a labeled batch (first) and a pseudolabeled batch from\n",
    "    dl2 (last). The train dataloader's batch_tfms are disabled during fit and applied by this callback to the\n",
    "    combined batch instead. The original batch_tfms, batch size and loss function are restored in `after_fit`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dl2:DataLoader, bs:Optional[int]=None, l2pl_ratio:int=1, batch_tfms:Optional[list]=None, do_setup:bool=True,\n",
    "                 pseudolabel_sample_weight:float=1., verbose=False):\n",
    "        r'''\n",
    "        Args:\n",
    "            dl2:                       dataloader with the pseudolabels\n",
    "            bs:                        batch size of the new, combined dataloader. If None, it will pick the bs from the labeled dataloader.\n",
    "            l2pl_ratio:                number of labeled samples per pseudolabeled sample in the combined batch\n",
    "            batch_tfms:                transforms applied to the combined batch. If None, it will pick the batch_tfms from the labeled dataloader (if any)\n",
    "            do_setup:                  perform a transform setup on the labeled dataset.\n",
    "            pseudolabel_sample_weight: weight of each pseudolabel sample relative to a labeled sample in the training loss\n",
    "                                       (1. = same weight, 0. = pseudolabels ignored). Should be >= 0.\n",
    "            verbose:                   print the batch composition while training.\n",
    "        '''\n",
    "\n",
    "        self.dl2, self.bs, self.l2pl_ratio, self.batch_tfms, self.do_setup, self.verbose = dl2, bs, l2pl_ratio, batch_tfms, do_setup, verbose\n",
    "        self.pl_sw = pseudolabel_sample_weight\n",
    "\n",
    "    def before_fit(self):\n",
    "        \"Split the batch size between labeled and pseudolabeled samples and install the combined loss.\"\n",
    "        if self.batch_tfms is None: self.batch_tfms = self.dls.train.after_batch\n",
    "        self.old_bt = self.dls.train.after_batch # Remove and store dl.train.batch_tfms\n",
    "        self.old_bs = self.dls.train.bs\n",
    "        self.dls.train.after_batch = noop # batch_tfms are applied to the combined batch in `before_batch`\n",
    "\n",
    "        if self.do_setup and self.batch_tfms:\n",
    "            for bt in self.batch_tfms:\n",
    "                bt.setup(self.dls.train)\n",
    "\n",
    "        if self.bs is None: self.bs = self.dls.train.bs\n",
    "        self.dl2.to(self.dls.device)\n",
    "        # At least 1 pseudolabeled sample per batch (a 0-sized dl2 batch would also divide by zero below)\n",
    "        self.dl2.bs = max(1, min(len(self.dl2.dataset), int(self.bs / (1 + self.l2pl_ratio))))\n",
    "        self.dls.train.bs = self.bs - self.dl2.bs\n",
    "        # Default split point used by `loss`; updated with the real labeled count in every `before_batch`\n",
    "        self.n_lbl = self.dls.train.bs\n",
    "        pv(f'labels / pseudolabels per training batch              : {self.dls.train.bs} / {self.dl2.bs}', self.verbose)\n",
    "        # How much more often each labeled sample is drawn than each pseudolabeled one\n",
    "        rel_weight = (self.dls.train.bs/self.dl2.bs) * (len(self.dl2.dataset)/len(self.dls.train.dataset))\n",
    "        pv(f'relative labeled/ pseudolabel sample weight in dataset: {rel_weight:.1f}', self.verbose)\n",
    "\n",
    "        self.dl2iter = iter(self.dl2)\n",
    "\n",
    "        self.old_loss_func = self.learn.loss_func\n",
    "        self.learn.loss_func = self.loss\n",
    "\n",
    "    def before_batch(self):\n",
    "        \"Append a pseudolabeled batch to each training batch and apply `batch_tfms` to the combined batch.\"\n",
    "        if not self.training: return\n",
    "        X, y = self.x, self.y\n",
    "        try: X2, y2 = next(self.dl2iter)\n",
    "        except StopIteration:\n",
    "            # dl2 is exhausted: restart it so pseudolabels keep cycling for the rest of training\n",
    "            self.dl2iter = iter(self.dl2)\n",
    "            X2, y2 = next(self.dl2iter)\n",
    "        # Mixed target formats: one-hot encode the hard (class index) targets so both can be concatenated\n",
    "        if y.ndim == 1 and y2.ndim == 2: y = torch.eye(self.learn.dls.c, device=y.device, dtype=y2.dtype)[y]\n",
    "        elif y.ndim == 2 and y2.ndim == 1: y2 = torch.eye(self.learn.dls.c, device=y2.device, dtype=y.dtype)[y2]\n",
    "        # Real number of labeled samples (they come first in the combined batch); may be < dls.train.bs\n",
    "        self.n_lbl = len(X)\n",
    "\n",
    "        X_comb, y_comb = concat(X, X2), concat(y, y2)\n",
    "\n",
    "        if self.batch_tfms is not None:\n",
    "            X_comb = compose_tfms(X_comb, self.batch_tfms, split_idx=0)\n",
    "            y_comb = compose_tfms(y_comb, self.batch_tfms, split_idx=0)\n",
    "        self.learn.xb = (X_comb,)\n",
    "        self.learn.yb = (y_comb,)\n",
    "        pv(f'\\nX: {X.shape}  X2: {X2.shape}  X_comb: {X_comb.shape}', self.verbose)\n",
    "        pv(f'y: {y.shape}  y2: {y2.shape}  y_comb: {y_comb.shape}', self.verbose)\n",
    "\n",
    "    def loss(self, output, target):\n",
    "        \"\"\"Loss used during fit in place of the learner's original `loss_func`.\n",
    "\n",
    "        2D (one-hot/soft) targets are converted to class indices (argmax) before calling the original loss.\n",
    "        In training with `pseudolabel_sample_weight != 1`, the labeled and pseudolabeled parts of the batch are\n",
    "        scored separately and combined so that each pseudolabeled sample weighs `pseudolabel_sample_weight`\n",
    "        times a labeled one. Assumes the original loss uses mean reduction, in which case a weight of 1\n",
    "        gives exactly the plain loss over the whole batch.\n",
    "        \"\"\"\n",
    "        if target.ndim == 2: _, target = target.max(dim=1)\n",
    "        if self.training and self.pl_sw != 1:\n",
    "            n_lbl = self.n_lbl\n",
    "            n_pl = len(target) - n_lbl\n",
    "            lbl_loss = self.old_loss_func(output[:n_lbl], target[:n_lbl])\n",
    "            pl_loss = self.old_loss_func(output[n_lbl:], target[n_lbl:])\n",
    "            # Per-sample weighted mean: labeled samples weigh 1, pseudolabeled samples weigh pl_sw\n",
    "            return (n_lbl * lbl_loss + self.pl_sw * n_pl * pl_loss) / (n_lbl + self.pl_sw * n_pl)\n",
    "        return self.old_loss_func(output, target)\n",
    "\n",
    "    def after_fit(self):\n",
    "        \"Restore the train dataloader's batch_tfms and batch size, and the original loss function.\"\n",
    "        self.dls.train.after_batch = self.old_bt\n",
    "        self.learn.loss_func = self.old_loss_func\n",
    "        self.dls.train.bs = self.old_bs\n",
    "        self.dls.bs = self.old_bs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.data.all import *\n",
    "from tsai.models.all import *\n",
    "from tsai.tslearner import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the NATOPS dataset as float32 inputs; `splits` is passed to TSClassifier in the cells below\n",
    "dsid = 'NATOPS'\n",
    "X, y, splits = get_UCR_data(dsid, return_split=False)\n",
    "X = X.astype('float32')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "labels / pseudolabels per training batch              : 171 / 85\n",
      "relative labeled/ pseudolabel sample weight in dataset: 4.0\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.782144</td>\n",
       "      <td>1.758471</td>\n",
       "      <td>0.250000</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "X: torch.Size([171, 24, 51])  X2: torch.Size([85, 24, 51])  X_comb: torch.Size([256, 24, 41])\n",
      "y: torch.Size([171])  y2: torch.Size([85])  y_comb: torch.Size([256])\n"
     ]
    }
   ],
   "source": [
    "# Demo only: the labeled inputs and their true labels stand in for unlabeled data and its pseudolabels\n",
    "pseudolabeled_data = X\n",
    "soft_preds = False  # False: hard pseudolabels (class indices); True: one-hot/soft pseudolabels\n",
    "\n",
    "pseudolabels = OneHot()(y) if soft_preds else ToNumpyCategory()(y)\n",
    "dsets2 = TSDatasets(pseudolabeled_data, pseudolabels)\n",
    "dl2 = TSDataLoader(dsets2, num_workers=0)\n",
    "noisy_student_cb = NoisyStudent(dl2, bs=256, l2pl_ratio=2, verbose=True)\n",
    "tfms = [None, TSClassification]\n",
    "learn = TSClassifier(X, y, splits=splits, tfms=tfms, batch_tfms=[TSStandardize(), TSRandomSize(.5)], cbs=noisy_student_cb)\n",
    "learn.fit_one_cycle(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "labels / pseudolabels per training batch              : 171 / 85\n",
      "relative labeled/ pseudolabel sample weight in dataset: 4.0\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
       "        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.898401</td>\n",
       "      <td>1.841182</td>\n",
       "      <td>0.155556</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "X: torch.Size([171, 24, 51])  X2: torch.Size([85, 24, 51])  X_comb: torch.Size([256, 24, 51])\n",
      "y: torch.Size([171, 6])  y2: torch.Size([85, 6])  y_comb: torch.Size([256, 6])\n"
     ]
    }
   ],
   "source": [
    "# Demo only: the labeled inputs and their true labels stand in for unlabeled data and its pseudolabels\n",
    "pseudolabeled_data = X\n",
    "soft_preds = True  # True: one-hot/soft pseudolabels; False: hard pseudolabels (class indices)\n",
    "\n",
    "pseudolabels = OneHot()(y) if soft_preds else ToNumpyCategory()(y)\n",
    "# Soft targets are cast to float32; hard class indices must stay integers\n",
    "if soft_preds: pseudolabels = pseudolabels.astype(np.float32)\n",
    "dsets2 = TSDatasets(pseudolabeled_data, pseudolabels)\n",
    "dl2 = TSDataLoader(dsets2, num_workers=0)\n",
    "noisy_student_cb = NoisyStudent(dl2, bs=256, l2pl_ratio=2, verbose=True)\n",
    "tfms = [None, TSClassification]\n",
    "learn = TSClassifier(X, y, splits=splits, tfms=tfms, batch_tfms=[TSStandardize(), TSRandomSize(.5)], cbs=noisy_student_cb)\n",
    "learn.fit_one_cycle(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/026_callback.noisy_student.ipynb saved at 2024-02-10 21:53:24\n",
      "Correct notebook to script conversion! 😃\n",
      "Saturday 10/02/24 21:53:27 CET\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcO
wn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "from tsai.export import get_nb_name\n",
    "from tsai.imports import create_scripts\n",
    "# Convert this notebook into the library's script files\n",
    "nb_name = get_nb_name(locals())\n",
    "create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
